Add fused_rope forward op #54351

Merged (11 commits into PaddlePaddle:develop) Jun 29, 2023
Conversation

@AnnaTrainingG (Contributor) commented Jun 5, 2023

PR types

Others

PR changes

Others

Description

Pcard-70458
Others

@paddle-bot (bot) commented Jun 5, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@AnnaTrainingG force-pushed the fuse_broadcast branch 2 times, most recently from 0b1089d to 8e174d1 on June 13, 2023.
@Xreki (Contributor) left a comment:

Please add a unit test.
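
For reference, a minimal sketch of the kind of unit test being requested, assuming the op ends up exposed as paddle.incubate.nn.functional.fused_rotary_position_embedding (the relocation suggested later in this review) and a CUDA device; the shapes are illustrative only.

import unittest

import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding


class TestFusedRopeForward(unittest.TestCase):
    def test_shape_and_dtype(self):
        # [batch_size, seq_len, num_heads, head_dim]; head_dim must be even.
        shape = [2, 8, 4, 16]
        q, k, v = (paddle.randn(shape, dtype="float32") for _ in range(3))
        # The fused op rotates q, k, and v in a single kernel launch.
        out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
        for out in (out_q, out_k, out_v):
            self.assertEqual(out.shape, shape)
            self.assertEqual(out.dtype, q.dtype)


if __name__ == "__main__":
    unittest.main()

A real test would also compare the outputs numerically against an independent reference, as discussed further down in this thread.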

@@ -0,0 +1,160 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Contributor:

Please move the fused operator implementation into the phi/kernels/fusion/gpu directory.

template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorType {
T val[VecSize];
};
Contributor:

Why not use AlignedVector directly?

@AnnaTrainingG (Author):

Fixed.

phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
Contributor:

Is setting ALL_BACKEND needed here?

@AnnaTrainingG (Author):

Removed.

Contributor:

Was this removed and then added back? The forward pass as well?

#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {
template <typename T, int VecSize>
Contributor:

Use AlignedVector directly.

@AnnaTrainingG (Author):

Fixed.

int C,
int main_offset,
phi::Array<T*, 3> outs_data,
int break_iter,
Contributor:

Rename break_iter to num_inputs.

@AnnaTrainingG (Author):

Fixed.

auto N = q.dims()[0];
auto H = q.dims()[1];
auto W = q.dims()[2];
auto C = q.dims()[3];
Contributor:

q is a sequence; its four dimensions are [batch_size, seq_len, num_heads, head_dim]. Name the variables after what the dimensions mean.

@AnnaTrainingG (Author):

Fixed.
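
For clarity, a short hedged sketch of the renaming the reviewer asks for (the shape values are illustrative):

import paddle

# Name q's dimensions after their meaning rather than N, H, W, C.
q = paddle.randn([2, 8, 4, 16], dtype="float32")  # [batch_size, seq_len, num_heads, head_dim]
batch_size, seq_len, num_heads, head_dim = q.shape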

@@ -621,6 +621,11 @@ def add(x, y, name=None):
    return _elementwise_op(LayerHelper('elementwise_add', **locals()))


def fused_rope(q, k, v):
    if in_dynamic_mode():
        return _C_ops.fused_rope(q, k, v)
@Xreki (Contributor) commented Jun 19, 2023:

The fused_rope API would fit better under paddle.incubate.nn.functional, and the API name should spell out rotary_position_embedding in full.

@AnnaTrainingG (Author):

Fixed.
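
For context, the relocated API exists in released Paddle as paddle.incubate.nn.functional.fused_rotary_position_embedding. A hedged usage sketch (GPU-only; the (q, k, v) positional form follows this PR, anything beyond that is an assumption):

import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

# q, k, v: [batch_size, seq_len, num_heads, head_dim], head_dim even.
q, k, v = (paddle.randn([2, 8, 4, 16], dtype="float32") for _ in range(3))

# One fused kernel applies the rotary embedding to all three tensors.
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)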

@Xreki changed the title from "Add fused_fope forward op" to "Add fused_rope forward op" on Jun 19, 2023.
@AnnaTrainingG force-pushed the fuse_broadcast branch 4 times, most recently from a7e40fd to 80c112d on June 20, 2023.
YuanRisheng previously approved these changes Jun 27, 2023.
@@ -422,6 +422,17 @@
optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)

- op : fused_rope
Contributor:

This could go in fused_ops.yaml.

@AnnaTrainingG (Author):

Will change it in the next PR.

@@ -459,5 +459,11 @@ void IndexAddGradInferMeta(const MetaTensor& index,
int axis,
MetaTensor* x_grad,
MetaTensor* add_tensor_grad);
void FusedRopeGradInferMeta(const MetaTensor& dout_q,
Contributor:

Keep the functions in alphabetical order.

@AnnaTrainingG (Author) commented Jun 28, 2023:

Fixed.

@@ -3489,5 +3489,33 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}

Contributor:

Same as above.

@AnnaTrainingG (Author) commented Jun 28, 2023:

Fixed.

namespace phi {

template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
Contributor:

Fused kernels do not need header-file declarations.

@AnnaTrainingG (Author):

Will fix these all together in the next PR.

namespace phi {

template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
Contributor:

Same as above.

@AnnaTrainingG (Author):

Will fix these all together in the next PR.

jzhang533 previously approved these changes Jun 28, 2023.

@jzhang533 (Contributor) left a comment:

LGTM

ZzSean previously approved these changes Jun 28, 2023.

@ZzSean (Contributor) left a comment:

LGTM for skipIf

zyfncg previously approved these changes Jun 28, 2023.

@AnnaTrainingG dismissed stale reviews from zyfncg, ZzSean, and jzhang533 via 7703138 on June 28, 2023.
@Xreki (Contributor) left a comment:

LGTM. Some of the review suggestions can be addressed in the next PR.

PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
Contributor:

Change N, C, H, W here to their actual meanings as well; same below.

@AnnaTrainingG (Author):

OK.


template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const DenseTensor& dout_q,
Contributor:

I don't quite understand what dout_q means. Is it the forward output out_q?

@AnnaTrainingG (Author):

It is the dout passed back from the backward pass, i.e. the upstream gradient with respect to out_q.
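
Concretely: each per-pair rotation in RoPE is orthogonal, so the vector-Jacobian product is the same rotation applied with the angle negated. A minimal NumPy sketch of what the backward computes, assuming an interleaved even/odd pairing (the kernel's actual layout may differ):

import numpy as np


def rope_grad_reference(dout):
    # dout: upstream gradient w.r.t. out_q, shaped like q itself:
    # [batch_size, seq_len, num_heads, head_dim].
    batch_size, seq_len, num_heads, head_dim = dout.shape
    pos = np.arange(seq_len, dtype=np.float32)
    inv_freq = 1.0 / 10000.0 ** (
        np.arange(0, head_dim, 2, dtype=np.float32) / head_dim
    )
    angles = pos[:, None] * inv_freq[None, :]   # [seq_len, head_dim/2]
    cos = np.cos(angles)[None, :, None, :]      # broadcast over batch/heads
    sin = np.sin(angles)[None, :, None, :]
    d_even, d_odd = dout[..., 0::2], dout[..., 1::2]
    grad = np.empty_like(dout)
    # Transpose of the 2x2 rotation = rotation by -angle: sin flips sign.
    grad[..., 0::2] = d_even * cos + d_odd * sin
    grad[..., 1::2] = -d_even * sin + d_odd * cos
    return grad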

phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
Contributor:

Was this removed and then added back? The forward pass as well?

Fused rotary position embedding.

Args:
q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
Contributor:

Please add a description of the input tensors' shapes to the documentation.

@AnnaTrainingG (Author) commented Jun 29, 2023:

OK, will fix in the next PR.

indices = 1 / 10000 ** (indices / q.shape[3])
sinusoid_inp = pos_seq.unsqueeze(1) * indices.unsqueeze(0)

sin_sin = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
Contributor:

Why is part of the computation done with Paddle APIs and part with NumPy APIs?

@AnnaTrainingG (Author):

Because they have to be different (the reference must not be computed with the same ops as the implementation under test).
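
For readers of the test: a minimal NumPy-only forward reference in that spirit, assuming the same interleaved even/odd pairing as the backward sketch above. Since it shares no code with the CUDA kernel, it cannot share a bug with it.

import numpy as np


def rope_reference(x):
    # x: [batch_size, seq_len, num_heads, head_dim], head_dim even.
    batch_size, seq_len, num_heads, head_dim = x.shape
    pos = np.arange(seq_len, dtype=np.float32)
    inv_freq = 1.0 / 10000.0 ** (
        np.arange(0, head_dim, 2, dtype=np.float32) / head_dim
    )
    angles = pos[:, None] * inv_freq[None, :]   # [seq_len, head_dim/2]
    cos = np.cos(angles)[None, :, None, :]      # broadcast over batch/heads
    sin = np.sin(angles)[None, :, None, :]
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    # Rotate each (even, odd) pair by its position-dependent angle.
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out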

@sneaxiy merged commit a215c46 into PaddlePaddle:develop on Jun 29, 2023.
AnnaTrainingG added a commit to AnnaTrainingG/Paddle that referenced this pull request Aug 3, 2023
* style

* more

* update ctest

* Update legacy_backward.yaml

* Update legacy_ops.yaml

* Update legacy_ops.yaml

* update

* update

* update for move
sneaxiy added a commit that referenced this pull request Aug 7, 2023
* Add fused_rope forward op (#54351)

* style

* more

* update ctest

* Update legacy_backward.yaml

* Update legacy_ops.yaml

* Update legacy_ops.yaml

* update

* update

* update for move

* Update the rope op according to the comments (#54985)

* Update multiary.cc

* Update __init__.py

* for int64_t and assert

* more

* remove useless assert first

---------

Co-authored-by: sneaxiy <sneaxiy@126.com>
hitywt pushed a commit to hitywt/Paddle that referenced this pull request Nov 20, 2023
…addlePaddle#55931)

* Add fused_rope forward op (PaddlePaddle#54351)

* style

* more

* update ctest

* Update legacy_backward.yaml

* Update legacy_ops.yaml

* Update legacy_ops.yaml

* update

* update

* update for move

* Update the rope op according to the comments (PaddlePaddle#54985)

* Update multiary.cc

* Update __init__.py

* for int64_t and assert

* more

* remove useless assert first

---------

Co-authored-by: sneaxiy <sneaxiy@126.com>
hitywt pushed a commit to hitywt/Paddle that referenced this pull request Nov 22, 2023 (same commit message as above).
8 participants